# =============================================================================
# FST THEORY - ULTIMATE NUMERICAL SOLVER WITH NORMALIZATION
# =============================================================================

import numpy as np
import pandas as pd
import zipfile
from io import TextIOWrapper
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.optimize import curve_fit
from scipy.integrate import odeint
import astropy.constants as const
import astropy.units as u
from scipy.interpolate import interp1d
import traceback
import os
from tqdm import tqdm

# =============================================================================
# PHYSICAL CONSTANTS AND FST MODEL PARAMETERS
# =============================================================================
# Gravitational constant. The velocity models combine it with Msun and
# kpc_to_m, i.e. it is used as an SI value (m^3 kg^-1 s^-2).
G = 6.67430e-11
# Solar mass in kg; model masses are multiplied by this before use with G.
Msun = 1.989e30
# Metres per kiloparsec; model radii are multiplied by this before use with G.
kpc_to_m = 3.086e19

# FST model parameters. They are fixed here and not fitted by this script.
# Only 'mV' and 'lambda_val' are read by the code visible in this file.
# (Original note: values attributed to a prior Bayesian analysis.)
FST_PARAMS = {
    'c1': 0.51,
    'c2': -0.07, 
    'c3': 0.32,
    'mV': 3.2e-30,  # eV
    'lambda_val': 1.2e14
}

# Convert mV from eV to kg using astropy's mass-energy equivalency.
# NOTE(review): solve_fst_normalized uses this value as the coefficient of V
# in the field equation and as the decay constant in exp(-k * r), i.e. as an
# inverse length, while this conversion yields a mass. Confirm whether a
# Compton-wavenumber conversion (m * c / hbar) was intended.
mV_kg = (FST_PARAMS['mV'] * u.eV).to(u.kg, equivalencies=u.mass_energy()).value

# =============================================================================
# NORMALIZED FST SOLVER - SCALED EQUATIONS FOR NUMERICAL STABILITY
# =============================================================================

def solve_fst_normalized(r, V0, dVdr0=0.0):
    """
    Solve the FST field equation on a rescaled grid and sample it at ``r``.

    Equation integrated (in the scaled variables):
        d²V/dr² + (2/r)dV/dr - mV²V + (λ/6)V³ = 0

    The equation is integrated with ``odeint`` as an initial-value problem
    starting at the smallest requested radius, on a log-spaced grid that
    extends to twice the largest radius. The result is interpolated back
    onto ``r``.

    Parameters
    ----------
    r : array_like
        Radii at which the field is wanted. The caller in this file passes
        the same array it later multiplies by ``kpc_to_m``.
    V0 : float
        Field value imposed at the first grid point.
    dVdr0 : float, optional
        Field derivative imposed at the first grid point.

    Returns
    -------
    tuple of ndarray
        ``(V, dVdr)`` evaluated at ``r``.

    Notes
    -----
    If the integration fails for any reason, a message is printed and the
    exponential approximation ``V0 * exp(-k r)`` is returned instead.
    ``odeint`` reports failure through ``ODEintWarning`` rather than by
    raising, so the warning is promoted to an error here. Without that the
    fallback would never trigger and an unfinished solution array would be
    returned as if it were valid.

    NOTE(review): ``r`` and ``mV`` are both multiplied by ``SCALE_R``, and
    ``lambda`` carries an extra ``SCALE_LAMBDA`` factor. For a pure change of
    variables r' = r * SCALE_R the mass term would scale as mV / SCALE_R.
    The scale factors are left exactly as they were; confirm the intended
    nondimensionalisation before relying on the numerical solution.
    """
    # Local imports: these names are not in the file's top-level import block.
    import warnings
    from scipy.integrate import ODEintWarning

    try:
        r = np.asarray(r, dtype=float)

        # NORMALIZATION FACTORS (unchanged from the original implementation)
        SCALE_V = 1e25        # Scale for V field
        SCALE_R = 1e20        # Scale for radial distance
        SCALE_LAMBDA = 1e-25  # Scale for lambda parameter

        # Scaled parameters
        mV_scaled = mV_kg * SCALE_R
        lambda_scaled = FST_PARAMS['lambda_val'] * SCALE_LAMBDA / (SCALE_V**2)
        V0_scaled = V0 * SCALE_V
        dVdr0_scaled = dVdr0 * SCALE_V * SCALE_R

        # Scaled radial grid (log spacing for better resolution)
        r_scaled = r * SCALE_R
        r_min = max(min(r_scaled), 1e-10)
        r_max = max(r_scaled) * 2.0
        r_grid = np.logspace(np.log10(r_min), np.log10(r_max), len(r) * 3)

        def fst_equations_scaled(y, x):
            # First-order system: y = [V, dV/dx]
            V, dVdr = y
            x_safe = max(x, 1e-25)  # guard the 2/x term against x == 0
            d2Vdr2 = -(2 / x_safe) * dVdr + mV_scaled**2 * V - (lambda_scaled / 6) * V**3
            return [dVdr, d2Vdr2]

        # Promote integrator failure to an exception so the fallback runs.
        with warnings.catch_warnings():
            warnings.simplefilter("error", ODEintWarning)
            sol = odeint(fst_equations_scaled, [V0_scaled, dVdr0_scaled], r_grid,
                         full_output=0, mxstep=500000, rtol=1e-8, atol=1e-10)

        if not np.all(np.isfinite(sol)):
            raise FloatingPointError("ODE solution contains non-finite values")

        # Reverse scaling
        V_sol = sol[:, 0] / SCALE_V
        dVdr_sol = sol[:, 1] / (SCALE_V * SCALE_R)

        # Interpolate back to original r points
        V_interp = interp1d(r_grid / SCALE_R, V_sol, bounds_error=False, fill_value="extrapolate")
        dVdr_interp = interp1d(r_grid / SCALE_R, dVdr_sol, bounds_error=False, fill_value="extrapolate")

        return V_interp(r), dVdr_interp(r)

    except Exception as e:
        print(f"Normalized solver failed, using analytical approximation: {e}")
        # Exponential approximation; k is floored so it never collapses to 0.
        k = np.sqrt(max(mV_kg**2, 1e-40))
        V_approx = V0 * np.exp(-k * np.maximum(r, 1e-10))
        dVdr_approx = -k * V0 * np.exp(-k * np.maximum(r, 1e-10))
        return V_approx, dVdr_approx

def fst_velocity_final(r, M_total, r_d, V0=1e-3):
    """
    Rotation speed from the baryonic disk plus the FST field contribution.

    Parameters
    ----------
    r : array_like
        Radii; multiplied by ``kpc_to_m`` before entering the SI formulas.
    M_total : float
        Total disk mass; multiplied by ``Msun`` before use with ``G``.
    r_d : float
        Disk scale length, in the same units as ``r``.
    V0 : float, optional
        Field amplitude passed to ``solve_fst_normalized``.

    Returns
    -------
    ndarray
        Total speed, the quadrature sum of the Newtonian and FST terms after
        the SI result is divided by 1000. Entries that come out non-finite
        are replaced by the Newtonian value. If anything raises, the plain
        ``newtonian_velocity`` result is returned instead.
    """
    try:
        # Field and its radial derivative at the requested radii
        V, dVdr = solve_fst_normalized(r, V0)

        # Additional acceleration attributed to the field: a_V = -V * dV/dr
        a_V = -V * dVdr

        # Newtonian component (enclosed mass of an exponential profile)
        r_safe = np.maximum(r, 1e-10)
        M_enc = M_total * (1 - (1 + r_safe / r_d) * np.exp(-r_safe / r_d))
        v_newton = np.sqrt(G * M_enc * Msun / (r_safe * kpc_to_m)) / 1000

        # Equivalent circular speed of the extra acceleration: v = sqrt(|a| r)
        r_m = r_safe * kpc_to_m
        v_fst = np.sqrt(np.abs(a_V) * r_m) / 1000

        # Quadrature sum. Both terms must be squared: the previous code added
        # v_fst itself (a speed) to v_newton**2 (a speed squared).
        v_total = np.sqrt(v_newton**2 + v_fst**2)

        return np.nan_to_num(v_total, nan=v_newton, posinf=v_newton, neginf=v_newton)

    except Exception as e:
        print(f"FST velocity calculation error: {e}")
        # Fallback to the purely Newtonian curve
        return newtonian_velocity(r, M_total, r_d)

# =============================================================================
# THEORETICAL MODELS
# =============================================================================

def newtonian_velocity(r, M_total, r_d):
    """
    Circular speed produced by the baryonic mass alone.

    The enclosed mass is ``M_total * (1 - (1 + x) * exp(-x))`` with
    ``x = r / r_d``. Radii are floored at 1e-10 before use, and the SI
    result is divided by 1000. Any NaN or infinite entry is returned as 0.
    """
    with np.errstate(all='ignore'):
        radius = np.maximum(r, 1e-10)
        x = radius / r_d
        enclosed = M_total * (1 - (1 + x) * np.exp(-x))
        speed_si = np.sqrt(G * enclosed * Msun / (radius * kpc_to_m))
        speed = speed_si / 1000
        return np.nan_to_num(speed, nan=0.0, posinf=0.0, neginf=0.0)

def mond_velocity(r, M_total, r_d, a0=1.2e-10):
    """
    MOND rotation speed derived from the Newtonian baryonic curve.

    The Newtonian acceleration ``g_N = v_N**2 / r`` is formed in SI units and
    boosted with the closed-form inverse of the "standard" interpolating
    function ``mu(x) = x / sqrt(1 + x**2)``:

        g = g_N * sqrt(1/2 + 1/2 * sqrt(1 + (2 * a0 / g_N)**2))

    This tends to ``g_N`` for ``g_N >> a0`` and to ``sqrt(g_N * a0)`` for
    ``g_N << a0``.

    Parameters
    ----------
    r, M_total, r_d
        Passed through to ``newtonian_velocity``.
    a0 : float, optional
        MOND acceleration scale, compared against ``g_N`` in m/s^2.

    Returns
    -------
    ndarray
        Speed on the same scale as ``newtonian_velocity`` (SI result divided
        by 1000). NaN and infinite entries are returned as 0.

    Notes
    -----
    Two defects of the previous version are fixed. The speed-to-acceleration
    conversion used a factor of 1000 instead of 1000**2, so ``g_N`` was
    compared with ``a0`` in mismatched units. The old expression
    ``g_N / sqrt(1 + (a0 / g_N)**2)`` also reduced the acceleration in the
    low-acceleration regime instead of enhancing it.
    """
    v_newt = newtonian_velocity(r, M_total, r_d)
    with np.errstate(all='ignore'):
        # Same radius floor as newtonian_velocity, converted to metres
        r_m = np.maximum(r, 1e-10) * kpc_to_m

        # v_newt * 1000 undoes the /1000 applied by newtonian_velocity
        g_newt = np.where(v_newt > 0, (v_newt * 1000.0)**2 / r_m, 0.0)

        # Floor only inside the ratio so that g_newt == 0 still maps to 0
        g_ratio = a0 / np.maximum(g_newt, 1e-30)
        boost = np.sqrt(0.5 + 0.5 * np.sqrt(1.0 + (2.0 * g_ratio)**2))
        g_mond = g_newt * boost

        v_mond = np.sqrt(np.maximum(g_mond, 0) * r_m) / 1000.0
        return np.nan_to_num(v_mond, nan=0.0, posinf=0.0, neginf=0.0)

def lcdm_velocity(r, M_total, r_d, M_halo, c_halo):
    """
    Baryonic curve combined in quadrature with a halo term.

    The halo speed is ``100 * sqrt(M_halo / 1e12) * sqrt(ln(1 + c_halo * x) / x)``
    with ``x = r / r_d``. Both ``r_d`` and ``x`` are floored at 1e-10.
    NaN and infinite entries of the total are returned as 0.
    """
    baryonic = newtonian_velocity(r, M_total, r_d)
    with np.errstate(all='ignore'):
        scale_length = np.maximum(r_d, 1e-10)
        x = np.maximum(r / scale_length, 1e-10)
        amplitude = 100 * np.sqrt(M_halo / 1e12)
        halo = amplitude * np.sqrt(np.log(1 + c_halo * x) / x)
        combined = np.sqrt(baryonic**2 + halo**2)
        return np.nan_to_num(combined, nan=0.0, posinf=0.0, neginf=0.0)

# =============================================================================
# DATA LOADING AND PROCESSING
# =============================================================================

def load_sparc_data(zip_path):
    """
    Load per-galaxy rotation-curve tables from a ZIP archive.

    Every archive member whose name ends in ``_rotmod.dat`` is parsed as a
    whitespace-separated table with ``#`` comments. Rows with missing values
    are dropped, and only rows with ``R > 0.1`` and ``10 < Vobs < 500`` are
    kept. A galaxy is retained when at least 6 rows survive.

    Parameters
    ----------
    zip_path : str
        Path to the ZIP archive.

    Returns
    -------
    dict
        Maps the member name with ``_rotmod.dat`` removed to a DataFrame with
        columns ``R, Vobs, eVobs, Vgas, Vdisk, Vbulge``. The dict is empty if
        the archive cannot be opened; that error is printed, not raised.

    Notes
    -----
    Each member is read into memory once and decoded (UTF-8, then Latin-1,
    which accepts any byte sequence). The previous version re-wrapped the
    same open stream for every encoding attempt and called ``seek(0)`` on it
    after a failure, which can hit an already-closed stream and silently drop
    the file.

    Eight column names are supplied and the first six are kept.
    NOTE(review): standard SPARC rotmod files are understood to carry two
    trailing surface-brightness columns. With only six names pandas would
    promote the leading columns to an index and shift every column. Confirm
    against the actual data files. Files that really have six columns parse
    as before.
    """
    # Local import: StringIO is not in the file's top-level import block.
    from io import StringIO

    kept_columns = ["R", "Vobs", "eVobs", "Vgas", "Vdisk", "Vbulge"]
    all_columns = kept_columns + ["SBdisk", "SBbul"]

    galaxies_data = {}
    try:
        with zipfile.ZipFile(zip_path, 'r') as zp:
            rotmod_files = [f for f in zp.namelist() if f.endswith('_rotmod.dat')]

            print(f"Found {len(rotmod_files)} galaxy files")

            for file_name in tqdm(rotmod_files, desc="Loading galaxies"):
                try:
                    raw = zp.read(file_name)
                    try:
                        text = raw.decode('utf-8')
                    except UnicodeDecodeError:
                        text = raw.decode('latin-1')

                    df = pd.read_csv(StringIO(text),
                                     sep=r'\s+', comment="#", engine='python',
                                     header=None, names=all_columns,
                                     na_values=['NaN', 'nan', '-', ''])

                    # Clean and validate data
                    df = df[kept_columns].dropna()
                    df = df[(df['R'] > 0.1) & (df['Vobs'] > 10) & (df['Vobs'] < 500)]
                except Exception as e:
                    # Best effort: report the file and move on.
                    tqdm.write(f"Skipping {file_name}: {e}")
                    continue

                if len(df) >= 6:  # Minimum data points
                    galaxy_name = file_name.replace('_rotmod.dat', '')
                    galaxies_data[galaxy_name] = df

    except Exception as e:
        print(f"Failed to load ZIP file: {e}")

    return galaxies_data

def safe_curve_fit(func, xdata, ydata, p0, sigma=None, maxfev=20000, bounds=None):
    """
    Bounded ``curve_fit`` with one retry from a doubled initial guess.

    Parameters
    ----------
    func : callable
        Model ``func(x, *params)``.
    xdata, ydata : array_like
        Data to fit.
    p0 : sequence of float
        Initial guess. Its length sets the number of parameters.
    sigma : array_like, optional
        Per-point uncertainties forwarded to ``curve_fit``.
    maxfev : int, optional
        Evaluation budget forwarded to ``curve_fit``.
    bounds : tuple (lower, upper), optional
        Explicit parameter bounds. When omitted, defaults are sized to
        ``len(p0)``: the first two parameters use the ``[M_total, r_d]``
        limits ``[1e8, 1e15]`` and ``[0.1, 50]``. A third parameter of a
        3-parameter model uses the ``V0`` limits ``[1e-5, 1e-1]``. For
        models with more than three parameters, every parameter after the
        first two is only constrained to be non-negative.

    Returns
    -------
    tuple
        ``(popt, pcov, True)`` on success, ``(p0, None, False)`` when both
        attempts fail.

    Notes
    -----
    The previous version always passed three-element bounds, so
    ``curve_fit`` rejected every model that did not have exactly three
    parameters. Initial guesses are also clipped into the bounds, because
    ``curve_fit`` raises when ``p0`` lies outside them.
    """
    guess = np.asarray(p0, dtype=float)
    n_params = guess.size

    if bounds is None:
        if n_params <= 3:
            # Positional limits for [M_total, r_d, V0], truncated to fit.
            lower = np.array([1e8, 0.1, 1e-5][:n_params], dtype=float)
            upper = np.array([1e15, 50.0, 1e-1][:n_params], dtype=float)
        else:
            # The third slot is not V0 for larger models, so only the first
            # two limits are reused. The rest are kept non-negative.
            lower = np.concatenate(([1e8, 0.1], np.zeros(n_params - 2)))
            upper = np.concatenate(([1e15, 50.0], np.full(n_params - 2, np.inf)))
    else:
        lower = np.broadcast_to(np.asarray(bounds[0], dtype=float), guess.shape)
        upper = np.broadcast_to(np.asarray(bounds[1], dtype=float), guess.shape)

    # First attempt: caller's guess with loose tolerances.
    # Second attempt: doubled guess with curve_fit's default tolerances.
    attempts = (
        (guess, {'ftol': 1e-4, 'xtol': 1e-4}),
        (guess * 2.0, {}),
    )

    for start, tolerances in attempts:
        try:
            popt, pcov = curve_fit(func, xdata, ydata,
                                   p0=np.clip(start, lower, upper), sigma=sigma,
                                   maxfev=maxfev, bounds=(lower, upper),
                                   **tolerances)
            return popt, pcov, True
        except Exception:
            # Covers infeasible input, non-convergence and model errors.
            continue

    return p0, None, False

# =============================================================================
# GALAXY FITTING
# =============================================================================

def fit_galaxy(galaxy_name, df):
    """
    Fit the four rotation-curve models to one galaxy.

    Reads the ``R``, ``Vobs`` and ``eVobs`` columns of ``df``. Points are
    kept when the radius and speed are not NaN, the error is positive, the
    radius exceeds 0.1 and the speed lies strictly between 10 and 500.

    Returns ``None`` when fewer than 6 points survive or when an unexpected
    error occurs; the error is printed. Otherwise returns a dict with
    ``Galaxy``, ``Data_points`` and, for each of ``newton``, ``mond``,
    ``fst`` and ``lcdm``, a reduced chi-square entry and a success flag. The
    chi-square entry is 999.0 whenever the fit did not succeed.
    """
    try:
        radius = df["R"].values.astype(float)
        speed = df["Vobs"].values.astype(float)
        speed_err = df["eVobs"].values.astype(float)

        # Point selection
        keep = (~np.isnan(radius)) & (~np.isnan(speed)) & (speed_err > 0) & (radius > 0.1)
        keep &= (speed > 10) & (speed < 500)
        if np.sum(keep) < 6:
            return None

        r = radius[keep]
        v_obs = speed[keep]
        v_err = speed_err[keep]

        # Starting values from the peak of the observed curve
        peak = np.argmax(v_obs)
        v_max = np.max(v_obs)
        r_max = r[peak]
        dynamical_mass = (v_max**2 * r_max * kpc_to_m * 1000**2) / (G * Msun)
        M_est = min(max(dynamical_mass, 1e9), 1e13)

        results = {'Galaxy': galaxy_name, 'Data_points': len(r)}

        # (label, model callable, initial guess)
        candidates = [
            ('newton',
             lambda r, M, rd: newtonian_velocity(r, M, rd),
             [M_est, r_max/3]),
            ('mond',
             lambda r, M, rd: mond_velocity(r, M, rd),
             [M_est, r_max/3]),
            ('fst',
             lambda r, M, rd, V0: fst_velocity_final(r, M, rd, V0),
             [M_est, r_max/3, 1e-3]),
            ('lcdm',
             lambda r, M, rd, Mh, ch: lcdm_velocity(r, M, rd, Mh, ch),
             [M_est, r_max/3, 1e12, 10]),
        ]

        for label, predictor, guess in candidates:
            chi2_key = f'χ²_{label}'
            flag_key = f'success_{label}'
            try:
                popt, pcov, converged = safe_curve_fit(predictor, r, v_obs, guess, v_err)

                reduced_chi2 = 999.0
                if converged:
                    v_pred = predictor(r, *popt)
                    chi2 = np.sum(((v_obs - v_pred) / v_err)**2)
                    dof = len(r) - len(guess)
                    reduced_chi2 = chi2 / max(dof, 1)

                results[chi2_key] = reduced_chi2
                results[flag_key] = converged

            except Exception as e:
                results[chi2_key] = 999.0
                results[flag_key] = False

        return results

    except Exception as e:
        print(f"Error fitting {galaxy_name}: {e}")
        return None

# =============================================================================
# MAIN ANALYSIS
# =============================================================================

def main():
    """
    Run the whole pipeline: load the archive, fit each galaxy, write the
    results to ``fst_analysis_final.csv`` and print a per-model summary.

    Exits early, with a message, when the archive is missing, when no galaxy
    could be loaded, or when no fit produced a result. At most the first 137
    loaded galaxies are analysed.
    """
    banner = "=" * 70
    print("🚀 Starting ULTIMATE FST Analysis with Normalized Solver")
    print(banner)
    print("FST Parameters:", FST_PARAMS)
    print(banner)

    # Locate the input archive
    zip_path = "Rotmod_LTG.zip"
    if not os.path.exists(zip_path):
        print(f"❌ Data file not found: {zip_path}")
        print("📁 Please ensure Rotmod_LTG.zip is in the current directory")
        return

    # Load the rotation curves
    galaxies_data = load_sparc_data(zip_path)
    print(f"📊 Loaded {len(galaxies_data)} galaxies")
    if not galaxies_data:
        print("❌ No valid galaxies found")
        return

    # Fit every selected galaxy, keeping only the ones that returned a result
    print("\n🔬 Analyzing galaxies with normalized FST solver...")
    selected = list(galaxies_data.items())[:137]  # First 137 for testing
    results = []
    for galaxy_name, df in tqdm(selected, desc="Analyzing"):
        outcome = fit_galaxy(galaxy_name, df)
        if outcome:
            results.append(outcome)

    if not results:
        print("❌ No results obtained")
        return

    df_results = pd.DataFrame(results)
    df_results.to_csv('fst_analysis_final.csv', index=False)

    print(f"\n✅ Successfully analyzed {len(results)} galaxies")
    print("📊 Results saved to fst_analysis_final.csv")

    # Per-model mean of the reduced chi-square values
    rule = "=" * 50
    print("\n" + rule)
    print("FINAL RESULTS SUMMARY")
    print(rule)

    for model in ('newton', 'mond', 'fst', 'lcdm'):
        col_name = f'χ²_{model}'
        if col_name not in df_results.columns:
            continue
        chi2_values = df_results[col_name]
        valid_chi2 = chi2_values[chi2_values < 100]  # Filter out errors
        if len(valid_chi2) > 0:
            avg_chi2 = np.mean(valid_chi2)
            print(f"{model.upper():6}: χ² = {avg_chi2:.3f} (n={len(valid_chi2)})")
        else:
            print(f"{model.upper():6}: No valid fits")

    # FST performance analysis
    if 'χ²_fst' in df_results.columns:
        fst_success = df_results['success_fst'].sum()
        print(f"\n🎯 FST Performance: {fst_success}/{len(results)} successful fits")


if __name__ == "__main__":
    main()